{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# HyperParameter Tuning"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### `keras.wrappers.scikit_learn`\n",
    "\n",
    "Example adapted from: [https://github.com/fchollet/keras/blob/master/examples/mnist_sklearn_wrapper.py]()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Problem: \n",
    "Builds simple CNN models on MNIST and uses sklearn's GridSearchCV to find best model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "np.random.seed(1337)  # for reproducibility"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "from keras.datasets import mnist\n",
    "from keras.models import Sequential\n",
    "from keras.layers import Dense, Dropout, Activation, Flatten\n",
    "from keras.layers import Conv2D, MaxPooling2D\n",
    "from keras.utils import np_utils\n",
    "from keras.wrappers.scikit_learn import KerasClassifier\n",
    "from keras import backend as K"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from sklearn.model_selection import GridSearchCV"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data Preparation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "nb_classes = 10\n",
    "\n",
    "# input image dimensions\n",
    "img_rows, img_cols = 28, 28"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# load training data and do basic data normalization\n",
    "(X_train, y_train), (X_test, y_test) = mnist.load_data()\n",
    "\n",
    "if K.image_dim_ordering() == 'th':\n",
    "    X_train = X_train.reshape(X_train.shape[0], 1, img_rows, img_cols)\n",
    "    X_test = X_test.reshape(X_test.shape[0], 1, img_rows, img_cols)\n",
    "    input_shape = (1, img_rows, img_cols)\n",
    "else:\n",
    "    X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 1)\n",
    "    X_test = X_test.reshape(X_test.shape[0], img_rows, img_cols, 1)\n",
    "    input_shape = (img_rows, img_cols, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "X_train = X_train.astype('float32')\n",
    "X_test = X_test.astype('float32')\n",
    "X_train /= 255\n",
    "X_test /= 255\n",
    "\n",
    "# convert class vectors to binary class matrices\n",
    "y_train = np_utils.to_categorical(y_train, nb_classes)\n",
    "y_test = np_utils.to_categorical(y_test, nb_classes)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def make_model(dense_layer_sizes, filters, kernel_size, pool_size):\n",
    "    '''Creates model comprised of 2 convolutional layers followed by dense layers\n",
    "\n",
    "    dense_layer_sizes: List of layer sizes. This list has one number for each layer\n",
    "    nb_filters: Number of convolutional filters in each convolutional layer\n",
    "    nb_conv: Convolutional kernel size\n",
    "    nb_pool: Size of pooling area for max pooling\n",
    "    '''\n",
    "\n",
    "    model = Sequential()\n",
    "\n",
    "    model.add(Conv2D(filters, (kernel_size, kernel_size),\n",
    "                     padding='valid', input_shape=input_shape))\n",
    "    model.add(Activation('relu'))\n",
    "    model.add(Conv2D(filters, (kernel_size, kernel_size)))\n",
    "    model.add(Activation('relu'))\n",
    "    model.add(MaxPooling2D(pool_size=(pool_size, pool_size)))\n",
    "    model.add(Dropout(0.25))\n",
    "\n",
    "    model.add(Flatten())\n",
    "    for layer_size in dense_layer_sizes:\n",
    "        model.add(Dense(layer_size))\n",
    "        model.add(Activation('relu'))\n",
    "    model.add(Dropout(0.5))\n",
    "    model.add(Dense(nb_classes))\n",
    "    model.add(Activation('softmax'))\n",
    "\n",
    "    model.compile(loss='categorical_crossentropy',\n",
    "                  optimizer='adadelta',\n",
    "                  metrics=['accuracy'])\n",
    "\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "dense_size_candidates = [[32], [64], [32, 32], [64, 64]]\n",
    "my_classifier = KerasClassifier(make_model, batch_size=32)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## GridSearch HyperParameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/3\n",
      "40000/40000 [==============================] - ETA: 0s - loss: 0.8971 - acc: 0.694 - 10s - loss: 0.8961 - acc: 0.6953    \n",
      "Epoch 2/3\n",
      "40000/40000 [==============================] - 9s - loss: 0.5362 - acc: 0.8299     \n",
      "Epoch 3/3\n",
      "40000/40000 [==============================] - 10s - loss: 0.4425 - acc: 0.8594    \n",
      "39552/40000 [============================>.] - ETA: 0sEpoch 1/3\n",
      "40000/40000 [==============================] - 11s - loss: 0.7593 - acc: 0.7543    \n",
      "Epoch 2/3\n",
      "40000/40000 [==============================] - 10s - loss: 0.4489 - acc: 0.8597    \n",
      "Epoch 3/3\n",
      "40000/40000 [==============================] - 10s - loss: 0.3841 - acc: 0.8814    \n",
      "39648/40000 [============================>.] - ETA: 0sEpoch 1/3\n",
      "40000/40000 [==============================] - 10s - loss: 0.9089 - acc: 0.6946    \n",
      "Epoch 2/3\n",
      "40000/40000 [==============================] - 9s - loss: 0.5560 - acc: 0.8228     \n",
      "Epoch 3/3\n",
      "40000/40000 [==============================] - 10s - loss: 0.4597 - acc: 0.8556    \n",
      "39680/40000 [============================>.] - ETA: 0sEpoch 1/6\n",
      "40000/40000 [==============================] - 11s - loss: 0.8415 - acc: 0.7162    \n",
      "Epoch 2/6\n",
      "40000/40000 [==============================] - 10s - loss: 0.4929 - acc: 0.8423    \n",
      "Epoch 3/6\n",
      "40000/40000 [==============================] - 9s - loss: 0.4172 - acc: 0.8703     \n",
      "Epoch 4/6\n",
      "40000/40000 [==============================] - 10s - loss: 0.3819 - acc: 0.8812    \n",
      "Epoch 5/6\n",
      "40000/40000 [==============================] - 10s - loss: 0.3491 - acc: 0.8919    \n",
      "Epoch 6/6\n",
      "40000/40000 [==============================] - 10s - loss: 0.3284 - acc: 0.8985    \n",
      "39680/40000 [============================>.] - ETA: 0sEpoch 1/6\n",
      "40000/40000 [==============================] - 11s - loss: 0.7950 - acc: 0.7349    \n",
      "Epoch 2/6\n",
      "40000/40000 [==============================] - 10s - loss: 0.4913 - acc: 0.8428    \n",
      "Epoch 3/6\n",
      "40000/40000 [==============================] - 10s - loss: 0.4081 - acc: 0.8709    \n",
      "Epoch 4/6\n",
      "40000/40000 [==============================] - 10s - loss: 0.3613 - acc: 0.8870    \n",
      "Epoch 5/6\n",
      "40000/40000 [==============================] - 10s - loss: 0.3293 - acc: 0.8968    \n",
      "Epoch 6/6\n",
      "40000/40000 [==============================] - 10s - loss: 0.3024 - acc: 0.9058    \n",
      "39936/40000 [============================>.] - ETA: 0sEpoch 1/6\n",
      "40000/40000 [==============================] - 11s - loss: 0.9822 - acc: 0.6735    \n",
      "Epoch 2/6\n",
      "40000/40000 [==============================] - 10s - loss: 0.6270 - acc: 0.8009    \n",
      "Epoch 3/6\n",
      "40000/40000 [==============================] - 9s - loss: 0.5045 - acc: 0.8409     \n",
      "Epoch 4/6\n",
      "40000/40000 [==============================] - 10s - loss: 0.4396 - acc: 0.8599    \n",
      "Epoch 5/6\n",
      "40000/40000 [==============================] - 10s - loss: 0.3978 - acc: 0.8775    \n",
      "Epoch 6/6\n",
      "40000/40000 [==============================] - 10s - loss: 0.3605 - acc: 0.8871    \n",
      "39872/40000 [============================>.] - ETA: 0sEpoch 1/3\n",
      "40000/40000 [==============================] - 11s - loss: 0.6851 - acc: 0.7777    \n",
      "Epoch 2/3\n",
      "40000/40000 [==============================] - 10s - loss: 0.3989 - acc: 0.8776    \n",
      "Epoch 3/3\n",
      "40000/40000 [==============================] - 10s - loss: 0.3225 - acc: 0.9021    \n",
      "39552/40000 [============================>.] - ETA: 0sEpoch 1/3\n",
      "40000/40000 [==============================] - 11s - loss: 0.5846 - acc: 0.8164    \n",
      "Epoch 2/3\n",
      "40000/40000 [==============================] - 10s - loss: 0.3243 - acc: 0.9053    \n",
      "Epoch 3/3\n",
      "40000/40000 [==============================] - 10s - loss: 0.2697 - acc: 0.9213    \n",
      "39680/40000 [============================>.] - ETA: 0sEpoch 1/3\n",
      "40000/40000 [==============================] - 11s - loss: 0.6339 - acc: 0.8017    \n",
      "Epoch 2/3\n",
      "40000/40000 [==============================] - 10s - loss: 0.3417 - acc: 0.8975    \n",
      "Epoch 3/3\n",
      "40000/40000 [==============================] - 10s - loss: 0.2783 - acc: 0.9184    \n",
      "39648/40000 [============================>.] - ETA: 0sEpoch 1/6\n",
      "40000/40000 [==============================] - 11s - loss: 0.6652 - acc: 0.7854    \n",
      "Epoch 2/6\n",
      "40000/40000 [==============================] - 10s - loss: 0.3693 - acc: 0.8911    \n",
      "Epoch 3/6\n",
      "40000/40000 [==============================] - 10s - loss: 0.2923 - acc: 0.9130    \n",
      "Epoch 4/6\n",
      "40000/40000 [==============================] - 10s - loss: 0.2479 - acc: 0.9274    \n",
      "Epoch 5/6\n",
      "40000/40000 [==============================] - 10s - loss: 0.2176 - acc: 0.9360    \n",
      "Epoch 6/6\n",
      "40000/40000 [==============================] - 10s - loss: 0.1994 - acc: 0.9416    \n",
      "39616/40000 [============================>.] - ETA: 0sEpoch 1/6\n",
      "40000/40000 [==============================] - 11s - loss: 0.6463 - acc: 0.7952    \n",
      "Epoch 2/6\n",
      "40000/40000 [==============================] - 10s - loss: 0.3648 - acc: 0.8898    \n",
      "Epoch 3/6\n",
      "40000/40000 [==============================] - 10s - loss: 0.2880 - acc: 0.9154    \n",
      "Epoch 4/6\n",
      "40000/40000 [==============================] - 10s - loss: 0.2497 - acc: 0.9249    \n",
      "Epoch 5/6\n",
      "40000/40000 [==============================] - 10s - loss: 0.2154 - acc: 0.9357    \n",
      "Epoch 6/6\n",
      "40000/40000 [==============================] - 10s - loss: 0.1946 - acc: 0.9417    \n",
      "39584/40000 [============================>.] - ETA: 0sEpoch 1/6\n",
      "40000/40000 [==============================] - 11s - loss: 0.6212 - acc: 0.8012    \n",
      "Epoch 2/6\n",
      "40000/40000 [==============================] - 10s - loss: 0.3341 - acc: 0.9008    \n",
      "Epoch 3/6\n",
      "40000/40000 [==============================] - 10s - loss: 0.2706 - acc: 0.9195    \n",
      "Epoch 4/6\n",
      "40000/40000 [==============================] - 10s - loss: 0.2343 - acc: 0.9307    \n",
      "Epoch 5/6\n",
      "40000/40000 [==============================] - 10s - loss: 0.2109 - acc: 0.9383    \n",
      "Epoch 6/6\n",
      "40000/40000 [==============================] - 10s - loss: 0.1961 - acc: 0.9420    \n",
      "39648/40000 [============================>.] - ETA: 0sEpoch 1/3\n",
      "40000/40000 [==============================] - 12s - loss: 0.9322 - acc: 0.6835    \n",
      "Epoch 2/3\n",
      "40000/40000 [==============================] - 10s - loss: 0.5578 - acc: 0.8202    \n",
      "Epoch 3/3\n",
      "40000/40000 [==============================] - 11s - loss: 0.4651 - acc: 0.8518    \n",
      "40000/40000 [==============================] - 4s     \n",
      "Epoch 1/3\n",
      "40000/40000 [==============================] - 11s - loss: 0.7615 - acc: 0.7467    \n",
      "Epoch 2/3\n",
      "40000/40000 [==============================] - 10s - loss: 0.4369 - acc: 0.8634    \n",
      "Epoch 3/3\n",
      "40000/40000 [==============================] - 10s - loss: 0.3646 - acc: 0.8865    \n",
      "39904/40000 [============================>.] - ETA: 0sEpoch 1/3\n",
      "40000/40000 [==============================] - 12s - loss: 0.7744 - acc: 0.7471    \n",
      "Epoch 2/3\n",
      "40000/40000 [==============================] - 11s - loss: 0.4294 - acc: 0.8674    \n",
      "Epoch 3/3\n",
      "40000/40000 [==============================] - 11s - loss: 0.3620 - acc: 0.8873    \n",
      "39968/40000 [============================>.] - ETA: 0sEpoch 1/6\n",
      "40000/40000 [==============================] - 12s - loss: 0.8007 - acc: 0.7354    \n",
      "Epoch 2/6\n",
      "40000/40000 [==============================] - 10s - loss: 0.4769 - acc: 0.8499    \n",
      "Epoch 3/6\n",
      "40000/40000 [==============================] - 11s - loss: 0.4020 - acc: 0.8743    \n",
      "Epoch 4/6\n",
      "40000/40000 [==============================] - 11s - loss: 0.3551 - acc: 0.8905    \n",
      "Epoch 5/6\n",
      "40000/40000 [==============================] - 11s - loss: 0.3256 - acc: 0.8993    \n",
      "Epoch 6/6\n",
      "40000/40000 [==============================] - 11s - loss: 0.3005 - acc: 0.9067    \n",
      "39520/40000 [============================>.] - ETA: 0sEpoch 1/6\n",
      "40000/40000 [==============================] - 12s - loss: 0.8505 - acc: 0.7123    \n",
      "Epoch 2/6\n",
      "40000/40000 [==============================] - 10s - loss: 0.5156 - acc: 0.8321    \n",
      "Epoch 3/6\n",
      "40000/40000 [==============================] - 11s - loss: 0.4208 - acc: 0.8660    \n",
      "Epoch 4/6\n",
      "40000/40000 [==============================] - 11s - loss: 0.3614 - acc: 0.8854    \n",
      "Epoch 5/6\n",
      "40000/40000 [==============================] - 11s - loss: 0.3258 - acc: 0.8980    \n",
      "Epoch 6/6\n",
      "40000/40000 [==============================] - 11s - loss: 0.3044 - acc: 0.9046    \n",
      "39936/40000 [============================>.] - ETA: 0sEpoch 1/6\n",
      "40000/40000 [==============================] - 12s - loss: 0.7670 - acc: 0.7494    \n",
      "Epoch 2/6\n",
      "40000/40000 [==============================] - 11s - loss: 0.4593 - acc: 0.8574    \n",
      "Epoch 3/6\n",
      "40000/40000 [==============================] - ETA: 0s - loss: 0.3896 - acc: 0.880 - 11s - loss: 0.3898 - acc: 0.8799    \n",
      "Epoch 4/6\n",
      "40000/40000 [==============================] - 10s - loss: 0.3514 - acc: 0.8907    \n",
      "Epoch 5/6\n",
      "40000/40000 [==============================] - 10s - loss: 0.3124 - acc: 0.9020    \n",
      "Epoch 6/6\n",
      "40000/40000 [==============================] - 11s - loss: 0.2981 - acc: 0.9097    \n",
      "39680/40000 [============================>.] - ETA: 0sEpoch 1/3\n",
      "40000/40000 [==============================] - 12s - loss: 0.5547 - acc: 0.8239    \n",
      "Epoch 2/3\n",
      "40000/40000 [==============================] - 11s - loss: 0.2752 - acc: 0.9204    \n",
      "Epoch 3/3\n",
      "40000/40000 [==============================] - 11s - loss: 0.2183 - acc: 0.9359    \n",
      "39520/40000 [============================>.] - ETA: 0sEpoch 1/3\n",
      "40000/40000 [==============================] - 12s - loss: 0.5718 - acc: 0.8172    \n",
      "Epoch 2/3\n",
      "40000/40000 [==============================] - 11s - loss: 0.3141 - acc: 0.9054    \n",
      "Epoch 3/3\n",
      "40000/40000 [==============================] - 11s - loss: 0.2536 - acc: 0.9247    \n",
      "39680/40000 [============================>.] - ETA: 0sEpoch 1/3\n",
      "40000/40000 [==============================] - 12s - loss: 0.5111 - acc: 0.8399    \n",
      "Epoch 2/3\n",
      "40000/40000 [==============================] - 11s - loss: 0.2469 - acc: 0.9270    \n",
      "Epoch 3/3\n",
      "40000/40000 [==============================] - 11s - loss: 0.1992 - acc: 0.9422    \n",
      "20000/20000 [==============================] - 2s     \n",
      "40000/40000 [==============================] - 4s     \n",
      "Epoch 1/6\n",
      "40000/40000 [==============================] - 12s - loss: 0.6041 - acc: 0.8066    \n",
      "Epoch 2/6\n",
      "40000/40000 [==============================] - 11s - loss: 0.2951 - acc: 0.9132    \n",
      "Epoch 3/6\n",
      "40000/40000 [==============================] - 11s - loss: 0.2343 - acc: 0.9315    \n",
      "Epoch 4/6\n",
      "40000/40000 [==============================] - 11s - loss: 0.1995 - acc: 0.9418    \n",
      "Epoch 5/6\n",
      "40000/40000 [==============================] - 11s - loss: 0.1779 - acc: 0.9487    \n",
      "Epoch 6/6\n",
      "40000/40000 [==============================] - 11s - loss: 0.1612 - acc: 0.9540    \n",
      "39680/40000 [============================>.] - ETA: 0sEpoch 1/6\n",
      "40000/40000 [==============================] - 12s - loss: 0.6137 - acc: 0.8069    \n",
      "Epoch 2/6\n",
      "40000/40000 [==============================] - 11s - loss: 0.3075 - acc: 0.9096    \n",
      "Epoch 3/6\n",
      "40000/40000 [==============================] - 11s - loss: 0.2309 - acc: 0.9325    \n",
      "Epoch 4/6\n",
      "40000/40000 [==============================] - 11s - loss: 0.1935 - acc: 0.9443    \n",
      "Epoch 5/6\n",
      "40000/40000 [==============================] - 11s - loss: 0.1679 - acc: 0.9518    \n",
      "Epoch 6/6\n",
      "40000/40000 [==============================] - 11s - loss: 0.1576 - acc: 0.9551    \n",
      "39680/40000 [============================>.] - ETA: 0sEpoch 1/6\n",
      "40000/40000 [==============================] - 12s - loss: 0.5143 - acc: 0.8400    \n",
      "Epoch 2/6\n",
      "40000/40000 [==============================] - 11s - loss: 0.2743 - acc: 0.9205    \n",
      "Epoch 3/6\n",
      "40000/40000 [==============================] - 11s - loss: 0.2248 - acc: 0.9350    \n",
      "Epoch 4/6\n",
      "40000/40000 [==============================] - 11s - loss: 0.1964 - acc: 0.9428    \n",
      "Epoch 5/6\n",
      "40000/40000 [==============================] - 11s - loss: 0.1736 - acc: 0.9496    \n",
      "Epoch 6/6\n",
      "40000/40000 [==============================] - 11s - loss: 0.1643 - acc: 0.9521    \n",
      "39840/40000 [============================>.] - ETA: 0sEpoch 1/6\n",
      "60000/60000 [==============================] - 18s - loss: 0.4674 - acc: 0.8567    \n",
      "Epoch 2/6\n",
      "60000/60000 [==============================] - 16s - loss: 0.2417 - acc: 0.9293    \n",
      "Epoch 3/6\n",
      "60000/60000 [==============================] - 16s - loss: 0.1966 - acc: 0.9428    \n",
      "Epoch 4/6\n",
      "60000/60000 [==============================] - 17s - loss: 0.1695 - acc: 0.9519    \n",
      "Epoch 5/6\n",
      "60000/60000 [==============================] - 16s - loss: 0.1504 - acc: 0.9571    \n",
      "Epoch 6/6\n",
      "60000/60000 [==============================] - 15s - loss: 0.1393 - acc: 0.9597    \n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "GridSearchCV(cv=None, error_score='raise',\n",
       "       estimator=<keras.wrappers.scikit_learn.KerasClassifier object at 0x7f434a86ce48>,\n",
       "       fit_params={}, iid=True, n_jobs=1,\n",
       "       param_grid={'filters': [8], 'pool_size': [2], 'epochs': [3, 6], 'dense_layer_sizes': [[32], [64], [32, 32], [64, 64]], 'kernel_size': [3]},\n",
       "       pre_dispatch='2*n_jobs', refit=True, return_train_score=True,\n",
       "       scoring='neg_log_loss', verbose=0)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "validator = GridSearchCV(my_classifier,\n",
    "                         param_grid={'dense_layer_sizes': dense_size_candidates,\n",
    "                                     # nb_epoch is avail for tuning even when not\n",
    "                                     # an argument to model building function\n",
    "                                     'epochs': [3, 6],\n",
    "                                     'filters': [8],\n",
    "                                     'kernel_size': [3],\n",
    "                                     'pool_size': [2]},\n",
    "                         scoring='neg_log_loss',\n",
    "                         n_jobs=1)\n",
    "validator.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The parameters of the best model are: \n",
      "{'filters': 8, 'pool_size': 2, 'epochs': 6, 'dense_layer_sizes': [64, 64], 'kernel_size': 3}\n",
      " 9920/10000 [============================>.] - ETA: 0sloss :  0.0577878101223\n",
      "acc :  0.9822\n"
     ]
    }
   ],
   "source": [
    "print('The parameters of the best model are: ')\n",
    "print(validator.best_params_)\n",
    "\n",
    "# validator.best_estimator_ returns sklearn-wrapped version of best model.\n",
    "# validator.best_estimator_.model returns the (unwrapped) keras model\n",
    "best_model = validator.best_estimator_.model\n",
    "metric_names = best_model.metrics_names\n",
    "metric_values = best_model.evaluate(X_test, y_test)\n",
    "for metric, value in zip(metric_names, metric_values):\n",
    "    print(metric, ': ', value)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# There's more:\n",
    "\n",
    "The `GridSearchCV` model in scikit-learn performs a complete search, considering **all** the possible combinations of Hyper-parameters we want to optimise.\n",
    "\n",
    "If we want to apply for an optmised and bounded search in the hyper-parameter space, I strongly suggest to take a look at:\n",
    "\n",
    "* `Keras + hyperopt == hyperas`: [http://maxpumperla.github.io/hyperas/](http://maxpumperla.github.io/hyperas/)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
